{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1967b8a4",
   "metadata": {},
   "source": [
    "# 导入包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dcbdab21",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-06T23:54:59.858209Z",
     "start_time": "2021-09-06T23:54:57.536076Z"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c352b54",
   "metadata": {},
   "source": [
    "# 生成数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a6880411",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:00:38.022551Z",
     "start_time": "2021-09-07T00:00:38.004550Z"
    }
   },
   "outputs": [],
   "source": [
    "def synthetic_data(w, b, num_examples):  #@save\n",
    "    \"\"\"生成 y = Xw + b + 噪声。\"\"\"\n",
    "    X = torch.normal(0, 1, (num_examples, len(w)))\n",
    "    y = torch.matmul(X, w) + b\n",
    "    y += torch.normal(0, 0.01, y.shape)\n",
    "    return X, y.reshape((-1, 1))\n",
    "\n",
    "true_w = torch.tensor([2, -3.4])\n",
    "true_b = 4.2\n",
    "features, labels = synthetic_data(true_w, true_b, 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "252771a0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:04:55.535280Z",
     "start_time": "2021-09-07T00:04:55.522279Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features: tensor([ 0.3447, -0.8437]) \n",
      "label: tensor([7.7599])\n"
     ]
    }
   ],
   "source": [
    "#features中的每一行都包含一个二维数据样本，labels中的每一行都包含一维标签值（一个标量）。\n",
    "print('features:', features[0],'\\nlabel:', labels[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc6b9cf1",
   "metadata": {},
   "source": [
    "## 分步骤看下各项数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "73cd7b72",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-06T23:55:44.021735Z",
     "start_time": "2021-09-06T23:55:43.130684Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.3528,  0.6179],\n",
       "        [-0.0931, -0.7613],\n",
       "        [-0.8542, -1.6903],\n",
       "        [-1.1490, -0.4645],\n",
       "        [-1.0691,  0.0077],\n",
       "        [ 0.4251, -0.5353],\n",
       "        [ 0.2390,  0.1696],\n",
       "        [ 0.8460,  0.4537],\n",
       "        [ 0.5301,  0.5635],\n",
       "        [ 0.1652,  2.5602]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.normal(0, 1, (10, 2))\n",
    "X      "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "610e92a0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-06T23:58:29.742214Z",
     "start_time": "2021-09-06T23:58:29.729213Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.6063,  6.6024,  8.2385,  3.4812,  2.0355,  6.8701,  4.1015,  4.3494,\n",
       "         3.3442, -4.1744])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w= torch.tensor([2, -3.4])\n",
    "b=4.2\n",
    "y = torch.matmul(X, w) + b \n",
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a64ea0ac",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-06T23:59:06.186298Z",
     "start_time": "2021-09-06T23:59:06.178298Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 2])\n",
      "torch.Size([2])\n"
     ]
    }
   ],
   "source": [
    "print(X.shape)\n",
    "print(w.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b08ef098",
   "metadata": {},
   "source": [
    "## 数据可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6aa7c9e4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:02:41.394608Z",
     "start_time": "2021-09-07T00:02:40.569560Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA1FklEQVR4nO2df3RU53nnvy+CSZGUBUkIFYNAIMlwZBerRrYpwcYYnMYtxe2e4jQ53VDv7sHes3FZ12c38cZt2qxbejbrUNXbU5tt7ZLTJq29bRMvG7YBBQOCgC1cQu0JQjNIGDAVo5EgSEM9SLz7x+i9eu+de2fund935vs5x0fozr3vfe9gvu9zn/f5IaSUIIQQ4l9mFXsChBBCsoNCTgghPodCTgghPodCTgghPodCTgghPmd2MW66YMEC2dLSUoxbE0KIbzl16tSIlLLRerwoQt7S0oK+vr5i3JoQQnyLEOKC3XG6VgghxOdQyAkhxOdQyAkhxOdQyAkhxOdQyAkhxOdQyAkhxOdQyAkhxOdQyEuA0Yk4Xj0cxuhEvNhTIYT4EF8JebkK3pt9F7Fr/1m82Xex2FMhhPiQomR2ZooSPAB4akNrkWeTO7Z1NZt+EkKIF3wl5OUqePU1gbJamAghhcVXQk7BI4SQZFz7yIUQrwkhrgoh3teO/a4Q4rIQ4vT0f7+Qn2kSQghxwstm518A+IzN8d1Sys7p/76Xm2kRQghxi2shl1IeATCax7kQQgjJgFyEH35RCHFm2vVS53SSEGKHEKJPCNEXiURycFtCCCFA9kL+pwBaAXQCuALgJacTpZR7pJRdUsquxsakBhcVQ7nGwhNCikdWQi6lHJZSTkkpbwP4XwDuz820yhcm/xBCck1W4YdCiEVSyivTv/4KgPdTnU/KNxaeEFI8XAu5EOLbAB4GsEAIcQnAVwE8LIToBCABDAF4KvdT9C+jE3G82XcR27qaUV8TAMBYeEJI7nEt5FLKz9kc/vMczqXsKNeSAoSQ0sJXmZ1+g24UQkghoJDnEbpRCCGFwFdlbAkhhCRDISeEEJ9DISeEEJ9DISeEEJ9DIS8yxUrZZ6kAQsoHCnmRcUrZz1Ro3V7HUgGElA8MPywyTrHmmSYTub2OMe6ElA8U8iLjFGvuVWhVOYDNHU2urmOMOyHlA10rJYoSWlWjJR3KEj8YHDYE2qtrhn5zQvwJLfIywWrBZ+KaYW0YQvwJhdzHWKsr6uLrxTXj1S1DCCktKOQ+JpUFbecDtyurm24cQkjpQyHPAidhLBReN0SdBJsRLIT4m4oV8lyIsJMwFkrgvUaeOAk2I1gI8TcVG7XiJiEmXRTHtq5mPP/YKscYcK/JNk73y1U0iddIGEKIP6hYi9yNOyGd7zhXMeDp7pdPH3ax3UOEkOypWCF3407IVJAzdVU43S+fPmxudBLifypWyNNRSEtVv5cbyz+Xc+NGJyH+p2J95OkoZFEpr/fK5dzoNyfE/9Aid6CQlqrXe/nRiqYvnpD8QYvcgXxbqnokitd7FcuKziZ6hmVzCckftMiLhB83GbOZsx/fIgjxCxRyj+TKRVAMYct27tnMWb1FKKueLhZCcgddKx7JhYvAq6DmKiEo27nnwqWTbg4spUuId2iReyQXlrRXF0Wu3DDbupoRi08hFp80fPOjE3HsPT4EQGL7uuWOIu128Ul3Xrrvz48uJ0KKDYXcI7moS1KsKJX6mgCqA1XYtf8sqgOz8dSGVrzZdxHdPQMAgDOXruOlJzptBditwO49PojunhBi8Uk8++hK2zmwBR0huYVCXgS8LgZez09lFVuFUlnpfUOjONQfwZt9F4176eM4CWzyvcT0JwKZwAJehHiHQl6GeKlTXl8TwLOP3mkSZKdx7ATWes72dS2oDlTRoiakgFDIy5BM3BN2lrCbcaznlItFzQQm4icYtVKG5CphyM04dudYI09GJ+LYfeAcdh/ox+hE3BeRKUxgIn6CFvk0+bTAKs26s7pb9A3V6kDif7lSi0yx/h1x05X4CddCLoR4DcAWAFellHdPH6sH8DcAWgAMAXhCSjmW+2nmn3yGvfklpC5fIYZqQxWQJmEsJZG0/h2Vi4uIVAZeLPK/APA/AXxTO/ZlAD1Syj8UQnx5+vcv5W56hSOfFlg+xraKaT5b11nv57XhhtpQ1ccqNexi7AnxC66FXEp5RAjRYjn8OICHp/+8F8Db8KmQ59MCy8fYdu6LXfvPIhafRHVgdkaCnmrB0e/nZWFSC8DmjiYcDA6nXQiK5Yayi7EnxC9k6yNvklJemf7zPwNoynI84hI79wUARCfi6O45i1h8yrCC3YpjqgYW+n28LEx7jw+hu2cARwdG0BsasZ2zTjHdUPSLE7+Ss81OKaUUQkinz4UQOwDsAIClS5fm6rYVi5374qkNrdh9oH/6yMxfxYy1PmXEeNfXBBCOjOPFfUG8sKUDddUBUxz5m30XEYtPGZuUTnHkTqhF4GZ8EgDQseiTeLB9QdqFYHNHE06cj2JzR+FtAvrFiV/JVsiHhRCLpJRXhBCLAFx1OlFKuQfAHgDo6upyFHySmnTW9dbOxThz6Tq2di42jilxjsUnTdbui/uCONQfARDE2hUNxmdAIqpk56Y2PP/YKk8WqpqfWgR2bmo3xki3gbr3+BBOXRhFbyiKtSuG0bqh1vV9CalkshXytwBsB/CH0z+/m/WMSErSuR4OBodxqD9iEkK9hKzynwPAC1s6AMxY5IDZrWAVXzcuGjU/fRFw4+vWQxRbG2s8WeSVFt5JiBUv4YffRmJjc4EQ4hKAryIh4G8IIf4dgAsAnsjHJMkMqfy4oxNxxOJT2LmpzfZzq+ugtbEWrz95v/G7/pmbdPx087MTVSfR3dzRhKMDEdyakjg5OIqDwZmFKJVQj07E8dwbp6ffLEo7vJOQfOElauVzDh9tytFcSJYoq/b5x1Z5ElG3uNkMTOdndloMDgaH0RuKYuemNqxd0YBYfBLhyDgOBoeTfPXW8Q71R7BxZSM3KUnFwsxOn5HKKs53re98lvDVj6t5nrl0HYf6I46++sQbyCR2bmrH9nUtdKuQioVC7jNSiXWp1Pq2s/wTm5mDAAS2dt6R9Lk+dzW/+1rqASQ2cFsba5PGj8Un0d0TcnwDIaRSYNGsIlGMwlGZFNPKZJ52BacSbp8QunsG8NXvfpCyIJWa57vTNdIPBodtxweE56gaQsoRCnmRyLS6Xj6r8tmJdqr7OYn8tq5m7NzUhlh8yvhsc0cTWhqqASRiyp0EWB8zMU67kTZvHr8dqnYLrXFS6VDIi8S2ruaMrEn9ulxb9XainWqeTiKfSHefje6eAbzZdxGjE3G8uC+IoWgM69sWYK6WlASYxVsfU6XNd/eETPdwOq7I5HtJd40fSu+SyoU+8iKR6cahft2rh8M5TWe386Gnmmcqn7t18/JQfwSfam0AINHdEzLVM0lVx8XN5qiVTDZ11TUnzkdt+5b6pYIlqVCklAX/b82aNZJkT3T8Y/nK2yEZHf84o8/dnpPqXC/3+Mb3++WyL+2Tn9/zQ/mN7/cb11jHyPY+oas30o5nd+1vvHZSLvvSPvnK2yFXz05IoQHQJ200lRa5j0lV5Kq+JuDKivRSidDu3HT3sOsFeurCKLp7BlAdqLKt/b33+CC6e0KIxSfx7KMrTfdxsphTzcXN91BfE8BLT3QmzVX/nJY4KVUo5GWEVbAy6blpN54qjavS5u2aQyifvTWs0CrK1YEq9IaiaRJ4hOVnYvwT56M41B/Bc2+cxgtbOnAwOGwqj5uoKT5pbLKqObgtxEWxJn6FQl5GZNsI2andWSw+5WjR6vfYfeAcunsGTCV0raKsRDVR5yXh57dGnmxf12JUadTv89ITnUY6/oejfQhHJnDkXATHwlFtoZg9XVO8ypiXXf0Zr98F67mQUoZCXkZkYlHqVjxg7qVpLrZV5WhBz5SsnZo+MlPccmvnHThz6Rq2dt4BYEZUgSBWL5lvm3rv9BxWMd+4shHN9dU4Fo7iZvw2gJlFbHNHk7FIZJII5dS4wzpXQkoBCnmFYWd1K3eEEltd8DKteKh46/RlHOqPYPWSy3j20ZUmF0l88jZ2bmq3Tb1XXYXeOn0ZgDBS8JWYqyzRm7emFw+B6VrsiczRmRK9M24mLxa22+gZL98TIfmCQl5h2DUZtnNHOJ2vowsu4FTxMNnfvXrJPNyakugNjeChOxsdQ/2U4AMwzU2fs1o8VLo+AOw78xHCkQmTH14f84UtHUlCb0Utcro4Z1IMjJBCQCGvMOwsS7fx4AprrRMAhgtG93knkmekUdQKmEnV37mpDQ+2L8DmjibsPnAOgMT2dcu1t4Qp3IxPormuGudHJnBfS71pbOu89h4fxI4HVyB45SfoDY1g48pGU3SL/iYABF1VTLQTZyffud1GsB203Ek+oJBXGHaWZSpr084ynXGltJtcKXZ+ZWtRK92HfTA4jLdOf2T4yVWS0Ezm5gA2rmxEb2gEc6qEyYK2Jkap+/zx537WEFZ9znp4oSrG9cKWDmNeets7VaDLbhGzhkF6tcRpuZN8QCEnaXEKa1RirEjnV9atUTXmvUvn4455P4UFtQFExz82wgY3dzThyLkImuuqsXNTGzbcudC4p45eylZ3gehZr/pCpD6zRrHobe/0eHK7UsHKstdjzt1upBaqAiWpLCjkJC1OYY1KLFWcuVX4rJb+K2+HsOfoIKLjH+Pph9tMPvCPrv8Lzlz+CRpqP4GnNrTiYHAYx8JRHAtH8fxjq3D43NXpDdL38bVfvtuIHbez+vW5bu5oSuogZBdXrre9S2U1WxOHvEYKMVad5AMKeYXixVfrJD5u4sx1glduGD+VIP7RgXM41H8VP9fagEXzfsq0aMTik7gZv43oRBw/ujgGADgWjhrW8xt9F/Hbv9iBjSsbDVHWn0u3vnV/uF1cud72rq4ruX+p9fuwupv0e+tJSvSDk0JAIa9QcuGr1ePMAZmUUalINJUYwh3zfgotDdX4relkofqaABbXzcXFsZv49cbapA3TZx9daXKRrG9rwJplddjauRiDI+8iHJnA7/6fDzAUjRminC679dXDYWzuaLLNAM30+1OirjZ/9TcNWt+kEFDIKxS71Hq3lqTVmtfDAc9cumZEi8xEt8z03ASAd4dGce+yOtM8VAKPfq6yync8tAJz58wyoloA4NMdTdhzdBAPtS/A5+6vdvTL21WLjMWnjDZyuQi5VGUIHlhej52b2rC1czHWrhhO6Qdn9ArJJRTyCsVO4NxaklahUxuODyyvw6H+CPYeH8Szj640JQrt3NSOsYmPcX4kZvJNj8XiOHE+iuh4HHuOnjclFen+b6u1/vTDbWio/YQpYiadcMbiU9i5qQ2ATBl+6L1EbiJG/uTgKB5ZtRCtjbVpSwF4eSOi6JN0UMiJySpOZ0nq56ufr7wdxp6j53Hv0vnTZ4ik8+prAnj1cBjfPPEhDgaTo0VuTd02BNwaquiU1KPQC3OpNwPALJDqHGUxn7l03RR+qONUVTLhkkksBvp3pGLkb8anjG5G+rh2Qpwqooe10IlXKOTEJFxKYFMJi1XogleuAwDmVJl7aFrPs7N0VbTIM4+0492hUeO4k/gdHYjgUH8EsfgZnBwcQyw+CT171NmanjlHLxuwfd1y1yUI1BuLNUKmviaAZx+903izURE81lBLYEaIrd9NKrFmyCJJB4Wc2OLFCvy9x+9OSqbR0UXZGu2hokWscd/WkEE1RseieegNRfHRtX8BANy8dRtz51Rh56Y2ox6L3Xz1ioqvHA4DAE6cTywc1uxUq7BbY9r1glx2IY9W8U5nfetuHzuxZsgiSQeFvELxWjQqFXront3YThUWdVG3i/vWfdh7jw+hu2cAOx5agY0rG43P586Zhe6eAZOVnIiSSRTV0ottKTGcOyfRqvbk4CjWrmhwzE7Vo1H0mPaDwWHbRU6/hyozEB3/GHuPDxnzsN5jJgN2IMnSJ8QtFPIKJZ3FnY0V6KbBhZ3VGotP4qvf/cC2VooqjTt3zqykTj6q6YWyktUmaeKzqiRLe/u65dNjJiolOmWn6qUIdjy4HMErN7C5owl11Yks1KMDI9jc0YTWxlqHSJ4qU7SOeiNIlwFLiFco5BVKvsTDzk1gXRSUsEbH41jf1mDURVHimyziwPZ1yw3fs53v/Te//R56Q4kGE1s7F+PoQAQdi+bZWtpPbWg1WsilamBtFfXe0CAOBofx1IZWDFwdR29oBC/uC+L1J++3XRi3dTUjOh5H8Mp13NTCKp/a0JpkydslGBHiFgp5hZKJxe2ma47uJgDsOwCp8/YcPQ8ARux6LD6JsYlbCEfGMRaLm1wlugWuxlTjRCfi6A1Fp0cWOBgcRm8oijlVs4z7ObV7U+Pc11KPJ19/B+1Nn8SeI+dNn+l/VnH3jbWfQF31HGz/uZakz/U5zw3MQm8oijXL6pJqtVu/D0amkEyhkBPXuOma47ThZxUntYF41+IZq7k6MBunLlwz0vCV393Jx67+/KnWBgDAp1objFBAvbCVqt1i1+5NLWhPvv6O0exCCa6136h6k9A3Yv+sdxDHz0fxweXr+Nov353kB9crRKaytFO9ITGOnKSDQk5c48a3a93wA2Ab5aGKYqnGEsrFsePB5ZhTJYw4ceWq2fHgcscuRnYZqbof3Voh0Q69aFZddWKDdmziFgAYbeSAhEAf6o9gyfy5WFw3Fx2LZiz4F/fNVE5UZQDUXDOtZwMkN7AmxAqFnLhGr62ihDmVG8CupKw63+rqsCYPKZSrZn1bA3pDUdyMT6GhNmAaH0jEv+vzAhLFvF55O4wPPrqOY+EoWhtrsLXzDlu3UF11AGtXNBgivmv/WbQ0VAMA5gaqAMyUzFVzaW+qxWfvXwoIgQ8uX0+qnKh3XsrOB57cZckJWu+VCYWceMarP1eF4ulZj3oTZrWxmarCYnQ84QcPXrlu+MOt4qjmFYtPGrVUFHXVcxCOTCRtTsbiU6gOVCXVeDlyLoJj4ajJXaM2ZNe3LcCnWhuMpKKGmgBe/vy9CeHsSLh17mupx+FzEeNNQlnVbr8zhV2XJbtzUiUfkfKHQk484zXiRYXiqaxHFZJ45FwiS/OZb703I4Q2xOJTUGKmwgXtREu3xA/1R7C+rQG3pm7j5OAYfmn1Ilwcu4kXtnQk1V3Ztf8s1rc1mCJtxLTxq37OWOML0Bsawc5N7XjozsakEr4zCxSMWPdD/ZGkbkpuLWeneuvWc5ySj0hlQCEnnvEa8WLno66vCaCrpd5ItNl7fNAUXqiu0zcWn39slakglVW0dNePcmcASBLM3QfOobtnwLBylfX+YPuMv743lHDF9IaiWgXHEHY8uAIPti8wRe7o99KjYG5NSTTXzU3qWerFck61z2A9xy40k1QGORFyIcQQgBsApgBMSim7cjEuKQ+crMoZV4EEIEy9MJWr5FB/BHXVc/BLqxclWZl2omVn6SYLmzR+qgYXr7wdxtGBCDZ3NBnx3z+6NIZNq5qMCBYAmBuYZSwWf/C9HydFq6g57T7Qj97QCHpDMJ7bWoZA/+mEijG3lixI9z24wS4DlviTXFrkG6WUIzkcj5QgXjfTUtURUcWm1HlnLl0zhQxu62rGt05+iAujMZwfmXB1P2tzZHVMn6/K7Lx56zZ2H+jH9nXLMXD1BnpDM2GPA1dv4OTgGKoDs1FfEzAlJKkxVbTK73znfbQ21uLIQATfeKJzutZ6wiezvq3BNsrHi/iqBc2p7G6m2GXAEn9C1wrxRCZd493UEbH2wlTHfv6un8aeo+fRsWie47X64rK5owlvTAvfm30XASBpvqoRRnfPTHSJHn4IIOl3u2xS1YJu8rbEN09cAAD85//9I/Q897CpSJfVWrebt9sY81xazSoJS68aSfxJroRcAvi+EEICeFVKucd6ghBiB4AdALB06dIc3ZYUGq+baV7Ot7NSn364FQ21AcfrrX50AAhHJtDaWGPURbG7v4qkASS2dTVjLBY3fW4tBJYYd9yo8ggAA8M3sLLpX+GbJy5g9eJ5+Mm/3MLXf/Uex2exZqimcpc4fS+5DC9MvBExLr0cyJWQr5dSXhZCLARwQAhxVkp5RD9hWtz3AEBXV5e0G4SUPl5cAroPNl/3s3M7qKxOVRfFyQpWbh1gRlTjk++jq6UegDS1lgMSLpRj4SgGR97Fkrpq9IZGjFjzjasWmtxEe48PAZDY2rkYB4PD2NzRZDTRUKh5WzcyrWKtLyBOlRdJZZMTIZdSXp7+eVUI8fcA7gdwJPVVpNzJpw9WLRI3b92erkW+3BDB1UvmYfWS+bZWvF2Z2m1dzYYrpb3pk0Y8uQqVVNy1eB6OhaMYisbw6Y4mzKkSeOaRdhw+FwEgjRh55U4CYETEqMVFX3Ci4x8jeOUG/ubdGX+7XemDmQUgaPj96QohOlkLuRCiBsAsKeWN6T9/GsDXsp4Z8SXWJhL58sHqi4ReoEuFCTr55HVXjzX1/fUn78foRBxz51RBuVz0Z/rsfc3GZ7q1/u7QaFKMvHLbqEbMehu9sVgcL+4LIhafwsnBUdyauo2NKxtxX0s9Xj0cTsp2fWFLB25NvY/2hZ8EkLklzqzP8iUXFnkTgL8XicyJ2QC+JaX8fzkYl/gQp3Kxuca6SOiNnlNVGTS7ahIun5vx29h94ByU8FYHqozSutaYb90dA+i1YFaYMlf181Tcu/qp3DgPLK/DxpWNaF9Yiz1HE+GNdj7zuuoA5lTNwp6j5439gkwE2RrRQzEvH7IWcinleQD35GAuxKdYrXAg/6/+1o06t5Ed5gYTLUnp+X1DYzgWjhquECC5Low1Jb67Z8DI4Dx14RrWLJtvstitlrC1QNfe44NG1qqy2vWGz8q1otwyToLspuuTtSokKQ8YfkiyxmqFZysQmbgAnDZFU7WdU3MdnYjj1IUx9IZGcNfieXjozkaTK0RtqKoSuKrtXCw+ZSQ1JUQ+Ibi9oRH83XuX8djPLMLTms87Fp80YtFVRMwf/N8g9hwdxI4Hl5uyVlXykO5bV6LtJMhuuj5ZQzxJeUAhJ1mTays8lSC5aW6RqoiU01zXLJuPNcvqTKn0yr1yX0u9EV0CADfjU8ZPvSzA6iXzDL/3xbGb2HPkvOFT37mpHUByv9IfXboOAAheuWGaj3oLeOaRdrQvrEXwyg0jRPLNvot4YUuHsdBYr9nc0ZQU6qh/R7TEyw8KeQWTq82vXItDqoXBTXMLvbKhElE9yciuamJ3Twg7N7VrdVUGTNawbpHPDSQ6D6mf6p6JMdpwT/N8/DAUwbWbkxib+BjfPPGh4bfXS9ru2n8WOx5aASklbk3dRjgyjtbGhEWuim+tXjIfPWevGpUb165ocFzk9OYZQHIDDrtrSHlAIa9gSrXkaaqFwS7dXf+p/zkWnzTEVcV1b1+3PEXVxEljw3TnpnaMTSR6iqp2bsoi15s3q81Nq4/+yddv4J8+iuD8SMzUIciu6UbPj4dNJXatz6ASnJRPXV1nLaKlomVU31Kn7wXwtogz2qX0oZBXMIXamMwlVpG3a3ZhroI427SZqddL0cVQnZ9AAJhJu780dhND0RjaF140okb0phFPbWhNmpe+odnaWJs0P8Vbpy8jHJlAS0M12hfWGguDWhz2Hh9KipNX1+n1zZXYAhLdPaGkGHjrwmgXT2/Xacl6rteql1wACsOs9KeQckWJTzn8I9t7fBC79p/F3uODhmgCCeHZvq4FOze1G4W71HOrLMnN33gb710Y02qwDAAQ2LmpHevbGjAUjQEA/uGDf8au/Wfx3BunsbmjCc8/tspYDGYWgQStjbV46YlOHAwOG8lLan6ALo4Czz+2Co93Lsaeo4NGfRh1TnfPgFG4y+463QeuH3eal2JbV7Nxvbr2xX1B7Np/1jQH67leUONaxyO5hxY5KRNm2qHZWZDKNz0WS9RmeWFLB7Z1NePVI2GMTtwyil2pDcOtnXcYlvTe44PTUS2JGuXKD50oV3vOiGCxxpjrYYLtTYlknlMXrmF0Ip7kitHrmitL1rrJqtd1t5ad1d8yDgaH8dbpj9DdM4AfnL2K2bMEvvbLdxv+d8D8JqPG1MMfFdlY1X584/MrtMhJwVCWspOV6PU6/fj2dS14/rFV2L6uJcmC1C1DFZP94r4g6msC+Ma2TtTXzMFv/2KiEJbaMFSW9Jt9F7G1czHWLKvHzk1t2POFLot1Ki0/E8W1fv3PTuDytZtGSzhIOd2oYgS/+e1/xN7jg8YY+tuDHiL58g8GjLmo50i4TaqSRNX6lgFIbFzZiJODozgWjuKr3/3A9nt97o3TxpitjbVJb2jZWNXl9MZX6tAiJwXDzlJ2Y/E5+WidYsKt4+mWoYr3VhUM+4dvYHTiFvqHb+DhVQtNlq3KwFTRK9YORQCS6pQDwIv7gugNRdEbimLHQysQmD0LEALhyASW1VdPN5xIlO5Xvnb92fQ5qPR+a+q+0/dmtYL/w1/24eTgGGLxW9h94JzJkndT59ytVU1/eHGhkJO8ki7r081GmpOYuB1P34isrwmYytPqkR5qo/GpDa149XDYEDm7mG2FXYSNqo2yorEWpz8cwztDY2hfWIvnH1uF6Hgce44mCmSdunANv/f4XaZnsApi64bapM5Ceoik/px6lUQlpn/6613GgvTeh9dN4Y/KdfPMI+2OIuw2tLRUI6AqBQo5ySvpsj7dWHxOYmIXsZKJBQkA3T0hnLpwDX/8uZ8FAKOrkYoWUVa4KlF7Mz6JuYHZti3SWhtr8Zf/fi1ePRzGN3+YiHyZOx1FMjoRx9xAFfqGRtEbGsFbpz9CdaDKuFbPGlU+d/2Z1Oc7HlyRtAGpV0lUi5XK5lSVImPxSaNYWGtjDcKRCQD2NV68QH94caGQk7yS7h94Jo2crdZjJiUC9GtuxicBAL2hEcMX7NTVSC9Rq3j20TtNtVFU/fHo+Md4YHk97lky38gYVQW11Pkqdh1IuFK+84+XjTnZhVOeujAKYKZ3qI4KeXzmkXbTtaoujbLsd25qN5KcWhtr8Mwj7Vi7oiFlg+d0MGO0uFDISV7J9T9wq2irqIsdD5mrD6bDauUCwPq2BY4JNPqxWHwKPwyP4J2hMahNTjWvN/ouIhyZMBXdemTVQkeXhYp139aVaLB8YTSG1sYazLXxnb/ZdxG9oSg2rmzUkpLMjSxULRU7N4d1UVUul3eHRg13kt119H+XPhRy4iusYqQiOZSFaU2EcUJfYOz6azqNoSzqrZ134MV9QWztXGyUsv1UawOOhaOGlbt6yTzDnaEWmFQ1UF7Y0oH45Pu4a/E8bO28w5iTQq+/klxiIDnhyboQWRdVawEtu0QpAKZNX7vytxT64kMhJ77CrgkyAFO1wmzHdIO1rkl3zwB2bmpHYPYsHOqP4PC5q6gOzAaQKBNw5tL1JGtZlRA4ci6Clz9/L1oba/HQnY3Ytf8sGmz8/+qewIxPW70h7HhwBeYGZiWVAsjke1CWeSw+aXQ4UvHzduVvudFZfCjkxNfootW6wTkVHnBnOaY7xy5Zx9rgWUWV7Np/Fi0N1XhgeZ0hgvrCo2K7j4WjhkDavXHoESvqWr3ErvLnexXRdC6YWHzKFLmj0vetcKOz+FDISVmRyjp0YzmmO0d9rlw5q5dcNtwZultmdCKOfWc+QjgygSV1c5MKZ+0+0I/e0AgeWF6PtSvqDRG0RuLosePWRUsdVz+tbhu9UJhdid90Lhg921SP3LGOoxiL0cVSLCjkpKxIZR1aO/14vV4/fl9LPYBEm7juHvskp6//6j14+QcDRuEsM4mSAmtXNBhRLEq4DwaHER3/GHuODuIHZ4exdkWDcZU1OmZbV7NxX32zEoBtE2kvET52Lho92kYV7QJgaoKhfxf6NRT4/EEhJ2VFKv+w7tfWrUuFU5OKRKErYcSM6wlDq5fMN8Vzq7R3lQmqJx/p6Bus+jVKDFsaqgEAJwfHcHJwzBBjFUf+7Xc+xFA0ZtqAtC5Cqh67XSSOUwKSEzMCPmXsB1jj2O32KUYn4njmW+/hWDhqNLkmuYdCTopCoa00vTiUU4q7U5MK3fJU4qsLorU2iTXt3e5e+oJjzSKNxc/g5OAY7l06D10tDZg7Z1ZSbZehaCxpA9K6iKmEotGJOHYf6AcgsLXzDtN3ki4iRX8ua3Nr/Vyry0e/7lg4Ov2bAMkPFHJSFAod6aAEWSX52MVM27leElEhk1A1ytO5JewEPl3t7+RrEoI3p6oKT29oxd7jg3jlcBhz58zChjsXom9oDHctnofP3tdsuwFp16dULUZnLl0z3B9AIvqlvmaObUSKkz9dhVFaN5Wd6tyo708lRZHcQyEnRaHQkQ7W+9nd3871orIiAZiSd6zoImYVeLVA3NdSb1jARwcShbNUKr5+zT3N83FycBT3NM9PeiM4c+k6joWjeOjOxqQCXoqZVncJAb0Zn5wOT6wylaodi8WNBCa7wlnWcgGpCpY5HdO/P5I/KOSkKBQ6pdups5COUwSIO9/xpKljj/mzKVP898aVjWhfWIve0Aj6hkaTslGf3tCKhpoANnc04a3TH+ELa5eif3gc9yyZj8/e32zyQ9slGKk3Cr2wlh6eqMI0X9wXNETc3q1iLtGbLuLFLqFIHzMX7jRunNpDISdkGr0xs7W6oO5PVseAGat1x0PJRaysfmXrZuDA1XFbl4a+odrdM2DUFX9k1cIkK1xvXrF6yXxjzqr0wBfWLsP5kfGkSB3dl6+LuC6U1hK96SJerAlF+vdkd30mMPnIHgo5IRp2m3rquH3t7oS1OndOchErO3+5LsLWFHmrtekma3VbVzOODkRwqD9ilMpVIt7dM4D1bQvQG4riYNAcqZNqs9ZJrN26w5zOy4U7jclH9lDICdFwEjin43aNJRTp3EfWz60+absEILsx1iyrR28oapTKTZBYYDoWfRJrltUlFRRzihHPpACZm+emSyS/sNUbIRpO7cncHs+0nV2C5LZxTuOFI+N48vV38N6FRAVG1cdTsX3dcjz/2Co8/XAbqgNV6O4JObZrU/dILCQhDAzfQHdPCK8cDuPJ199BODIOwFvbN+u8c9WIOVfjlBu0yAnRyNZydOvDtbuP1bp38ssDM00kBkcmMBSNYX1bg0nIdas4lTtCv4fuy1+95DK+848f4cJoDKpRhRe3hvV7yMQlkq5tH5mBQk6IRrabaW6FJl1LOnWOU09N1USiua4aQ9EL6A1FkzZNFalcPOoerY012Nq52CglUB2YbdRGV/1N9c3gdAud9Xuw6+aUbqF08x2RBBRyQjSytfjcCo2b+9j55XUr9fUn70c4Mo7zI+PoWDQvbWKQ0z2OnIvgWDiK3/nO+3j58/cmpfvr1+p+/O3rWpKSm4CZGuupio4BzgulXk/G7juivz0ZCjkpWzL5B18oi8/Nfazn2KXUv3X6I/SGotDT31PFttvdo6ulHsfCUVM5Xef5Jfz3PwyP4NSFUfSGojg6EEFvKIpYfArVgaqUQu1mAXNbgTLVc1UaFHJStpTbP3jdDaLiz5Ww9oZGsO2V49jzhS4cDA5j1/6z2PHQCqNmeipmfOsyrfW7fd1yo9kEAGxc2Yjm+mr0hqK4OW2lA8lJVWOxRALSC1s60v5dpBN7+smToZCTsqUU/sFn6gZItdGnuzLGYnH0DY3hw9EYwpEJvLgviJee6ASQ6ECUENxgyoJY1obQTnVilKX+0hOdRkXIDXc24rfeOA0g0RDaWgxMJSvFJ29PF88KOlaE1OfjtbxupZOT8EMhxGeEEP1CiJAQ4su5GJOQbHEKGSwkericl9BEuzA79TytjbXGc711+iMcC0excdVCo3qiOm/7uuVGA4x04Xqq3Oyu/WenRTrBtq7mpIxVILEZun1dC17+wYBRiVE1hNZrq6v737V4njE/t9+Dl+8ru7BP/5O1RS6EqALwJwAeBXAJwLtCiLeklMFsxybE7+hvBV5cPenqvqhjN+OTAIC66jn42uNmS1dZz9ZOPjqq3vqpC9dsy83aRaroz6GiZ9QCApiteP3+6nOnFH4rXr6vcnOjeSUXrpX7AYSklOcBQAjx1wAeB0AhJxVPqnjuVG4XOxcFkFxp0K7Bg10XoVR1xtWGaEtDNT59108nlZt1iglX41vdNk4urXTRKFa8FDErBTdaMcmFkC8GoL+3XQLwgPUkIcQOADsAYOnSpTm4LSH+wi5O3I0Vaa0quLmjyWiSoboW6eiFtNI1jdjWlagXfuL8KE4OjmLunKqk85xiwp0sa6dFCIAnqzndYuZ0biVSsM1OKeUeAHsAoKurS6Y5nRDf4GQt5sqKtAqnavO2c1Ob7f2UxZv4GbStsKiP/eyjK7H7QD9ODo7iZnzSSNgBzDHh1mQer7Hw1mOpcCogVqkWdzpyIeSXAejf7pLpY4SUPanS6NNZ3F4yJYGEiOkNjq2t0/YeH0R3T8jUGzNdhUWFKg8Qi085WtB2ESxeY+HdWs0zzTGmjPZ6lWxxpyMXQv4ugHYhxHIkBPzXAHw+B+MSUvI4pdFbe4Smut6tu6G+JoAXtnTg1tT76Fg0z6Z1mrD8tKuwmCz2+nmJLkhVthZ0Ia1idY9YfLKiNzHdkrWQSyknhRBfBPAPAKoAvCal/CDrmRHiA1LV9dZ7hLq53g0Hg8PoDUXxYHtj0rjb17UkiXAyyWKvY2dB6y4VL2KaTSq9eWGxLxOcD/ya/p8TH7mU8nsAvpeLsQgpFdz8o3ZyL3j1f7sl1bhuxnIn9mZSvTWk+o5yERJY6E1Mv4YxMrOTEAey+UftJEBuLb5wZNxIaVcVCVONm+28UpFq8Uj1HaULQyxFq9evm6psLEGIA05Zjdmw9/jQdPbkUMrzVL3xF/cl0jGsmYv67/nOgHRqnhGOjKfcB3DKrC3l5hClkA2cCbTICXEgP6/1yV2A7NAzJoFky9ep6XK+MiB1K9oap55uH8BKKqvXaq0Xy3ov5bcGOyjkhBQQux6fdqLR2lhrKi6VSNyZMnpp6qGIq5fMc/3msLmjCSfOR9NWRLSiLwDqPve11BtjeiFdo4tUIY+Fwm++cgo5IQXETsTSiYYSekCiuydkLASrl8zD6iXzbbM7nTgYHMah/gjWrhh2bOis0P301ugclaCkj2VXGkA9n1vL1inUsdA+a7/5yinkhBSZdKJhV1fFGt7o1hXgRaCUn16Vnk23mWl1uSiysWyLlXrvt5R/CjkhRSadaFiLVOnHrCIKpBbMVPeyWtTPPNIOAIaf3m4sPTNVn+faFcOmxcKtZes3l0apQCEnJM9ku3HmVKRKFzo7S9vrfe02MVWKf11XwHYMq/CqJB7r3NPhtTIiMUMhJyTP5MrK9JoM5PW+dhZ1ujGsm7D1NQHHMgCp0BeRVN2MSoFSjGhhHDkheSZX8ehuYpz1OHF1X1X+NlXsuC5OegeidHOvrwmgOlCF7p6QFhduLgNgF7tuPbatqxmfam3Aof6IqUOR07PlitGJOHYf6MfuA+dcj1uKcfC0yAnJM4XcOLNzdbjpyONkebuZu/VNwVoGwBq6+GbfRcTik0ZDCzW+EKnrwLiN7vFiKeuNNaoDVa7+njIN4cwnFHJCyohU9b+drGqnSo1uhdEq9tbf9fs7dTZ6s+8iekMj2Liy0aaqo7vnyMSFlXANTQIQngqXuQ3hLBhSyoL/t2bNGkkIKQ1eeTskl31pn3zl7ZCr49Hxj+Urb4dkdPzjpGOhqzeSPkt3barjXsjFGKV0HzsA9EkbTaVFTkiFo6zSWHzK2LBUx/WfCjvL1y6G3EtLNr1srd6FyAu5cmGlexMpxRhzCjkhFU5iw3I2du0/a/ITWwUrVYhgqhhyLzi5R1I1lM51FIkfY9kp5IQQVxmfqQROF/1s/Mb6grD7QD8Age3rWlJa/HabqdmIut/S8wEKOSEEzta3LohuBS4XnYFePRw2RZO4yRp1E/fuZQ5+gkJOCEnCThDdClwuxNQaTZLK4tc/y4c1XYoJQFYo5ISUMZmKUDaCmM21+nzdZoXqeLWm3Xw/fvCZM7OTkDIm0yzEbDrlZHOtdb7hyDiefP0dhCPjnsfK5H52ZJKZm48s1FTQIiekjPHbxp11vtZSuvm+nx2Z+MwLbcWLRIx5Yenq6pJ9fX0Fvy8hxJ5S9QM7NaFOR7GfJ1/3F0KcklJ2WY/TIieElKwf2Nryzi3Ffp5CR75QyAkhWbtgim0BW/GbSylbuNlJSIWQagMumw1KoPRKu2b7PJlS6E1OBS1yQiqEfLob/GABF+KtoVguHQo5IRVCPsXWD9mQVpF1EvZsBL9YCxqFnJAKwQ9im090kR2diOO5N07bVmrMxqou1ndMISeElD1WK/vVw2Ec6o9g48rGJOvZD24iKxRyQkhJkStftj6O1crWxdp6Dz++uVDICSElRa42DK3lbfWffhTrVFDICSElRa5cG3qT5HITbiuMIyeElBS5igFXTZIPBodzNLPSJSshF0L8rhDishDi9PR/v5CriRFCiBucknAyqVroZfxSIheuld1Syv+Rg3EIIcQzTj71XLlTil23xQ30kRNCfE2+wwX9EI6YCx/5F4UQZ4QQrwkh6nIwHiGEuCbfdVWKVbfFC2mFXAhxUAjxvs1/jwP4UwCtADoBXAHwUopxdggh+oQQfZFIJFfzJ4SQiidnjSWEEC0A9kkp7053LhtLEEKId5waS2QbtbJI+/VXALyfzXiEEEK8k62P/L8LIf5JCHEGwEYAz+ZgToQQkjP8ED6YLVkJuZTy30gpf0ZKuVpKuVVKeSVXEyOElDeFEthUTS8KLfL5uh/DDwkhRaFQ8dmpwgcLHSOer/tRyAkhRaFQ8dmpEoO2dTUjFp9CLD6J0Ym4qxBDvaoiAE+VGvP1zBRyQkhRKIVCVvU1AVQHqrBr/1lUB2a7mo9uVQPwZGHn65kp5ISQglOI/plu8Wol251f7KxPVj8khBScVBuQhUJtPALwlLmpZ3qWStYnLXJCSMEphfolfiiG5RYKOSGk4JSCf7wUFpNcQSEnhFQkpbCY5Ar6yAkhxOdQyAkhxOdQyAkhxOdQyAkhxOdQyAkhxOdQyAkhxOdQyAkhxOdQyAkhFY/fm09QyAkhFU8p1H7JBmZ2EkIqHr+n61PICSEVj9/T9elaIYQQn0MhJ4QQn0MhJ4QQn0MhJ4QQn0MhJ4QQn0MhJ4QQn0MhJ4QQnyOklIW/qRARABcKdLsFAEYKdK98U07PApTX85TTswDl9Tzl9CzLpJSN1oNFEfJCIoTok1J2FXseuaCcngUor+cpp2cByut5yulZnKBrhRBCfA6FnBBCfE4lCPmeYk8gh5TTswDl9Tzl9CxAeT1POT2LLWXvIyeEkHKnEixyQggpayjkhBDicypCyIUQ/00IcUYIcVoI8X0hxB3FnlOmCCG+LoQ4O/08fy+EmF/sOWWDEGKbEOIDIcRtIYQvQ8SEEJ8RQvQLIUJCiC8Xez7ZIIR4TQhxVQjxfrHnki1CiGYhxCEhRHD6/7GdxZ5TvqgIIQfwdSnlaillJ4B9AH6nyPPJhgMA7pZSrgZwDsDzRZ5PtrwP4F8DOFLsiWSCEKIKwJ8AeAxAB4DPCSE6ijurrPgLAJ8p9iRyxCSA56SUHQDWAviPPv+7caQihFxK+RPt1xoAvt3hlVJ+X0o5Of3rCQBLijmfbJFS/lhK2V/seWTB/QBCUsrzUso4gL8G8HiR55QxUsojAEaLPY9cIKW8IqV8b/rPNwD8GMDi4s4qP1RMqzchxO8D+AKA6wA2Fnk6ueLfAvibYk+iwlkMQO/YewnAA0WaC3FACNEC4GcBnCzyVPJC2Qi5EOIggJ+2+egrUsrvSim/AuArQojnAXwRwFcLOkEPpHuW6XO+gsSr418Vcm6Z4OZ5CMkXQohaAH8L4D9Z3s7LhrIRcinlZpen/hWA76GEhTzdswghfgPAFgCbpA8SATz83fiRywD01utLpo+REkAIMQcJEf8rKeXfFXs++aIifORCiHbt18cBnC3WXLJFCPEZAP8FwFYpZazY8yF4F0C7EGK5ECIA4NcAvFXkOREAQggB4M8B/FhK+Y1izyefVERmpxDibwGsBHAbifK5T0spfWk1CSFCAD4BIDp96ISU8ukiTikrhBC/AuBlAI0ArgE4LaX8+aJOyiNCiF8A8EcAqgC8JqX8/eLOKHOEEN8G8DASpV+HAXxVSvnnRZ1Uhggh1gM4CuCfkPi3DwD/VUr5veLNKj9UhJATQkg5UxGuFUIIKWco5IQQ4nMo5IQQ4nMo5IQQ4nMo5IQQ4nMo5IQQ4nMo5IQQ4nP+P9ZS1kLT2iupAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a75d8682",
   "metadata": {},
   "source": [
    "生成第二个特征features[:, 1]和labels的散点图，可以直观地观察到两者之间的线性关系。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "316bd3a4",
   "metadata": {},
   "source": [
    "# 读取数据集\n",
    "回想一下，训练模型时要对数据集进行遍历，每次抽取一小批量样本，并使用它们来更新我们的模型。 由于这个过程是训练机器学习算法的基础，所以有必要定义一个函数，该函数能打乱数据集中的样本并以小批量方式获取数据。\n",
    "\n",
    "在下面的代码中，我们定义一个data_iter函数， 该函数接收批量大小、特征矩阵和标签向量作为输入，生成大小为batch_size的小批量。每个小批量包含一组特征和标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c657651c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:09:23.625614Z",
     "start_time": "2021-09-07T00:09:23.610613Z"
    }
   },
   "outputs": [],
   "source": [
    "def data_iter(batch_size, features, labels):\n",
    "    #多少个样本\n",
    "    num_examples = len(features)\n",
    "    #生成0-num_examples的索引值\n",
    "    indices = list(range(num_examples))\n",
    "    \n",
    "    # 这些样本是随机读取的，没有特定的顺序\n",
    "    random.shuffle(indices)\n",
    "    \n",
    "    for i in range(0, num_examples, batch_size):\n",
    "        batch_indices = torch.tensor(\n",
    "            indices[i: min(i + batch_size, num_examples)])\n",
    "        yield features[batch_indices], labels[batch_indices]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21580121",
   "metadata": {},
   "source": [
    "通常，我们使用合理大小的小批量来利用GPU硬件的优势，因为GPU在并行处理方面表现出色。每个样本都可以并行地进行模型计算，且每个样本损失函数的梯度也可以被并行地计算，GPU可以在处理几百个样本时，所花费的时间不比处理一个样本时多太多。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b55472fd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:07:14.990256Z",
     "start_time": "2021-09-07T00:07:14.978256Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "indices = list(range(10))\n",
    "indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a5ee8eb7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:08:16.450772Z",
     "start_time": "2021-09-07T00:08:16.439771Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[9, 3, 1, 0, 6, 5, 2, 7, 8, 4]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random.shuffle(indices)\n",
    "indices"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c6ee9c9",
   "metadata": {},
   "source": [
    "取一次数据，看下数据的样子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "616e5adb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:09:26.537781Z",
     "start_time": "2021-09-07T00:09:26.196761Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.5771,  0.0579],\n",
      "        [ 0.1281,  1.6061],\n",
      "        [ 1.0820,  0.0167],\n",
      "        [-0.4927,  2.1243],\n",
      "        [-0.3875, -0.0284],\n",
      "        [-0.6186,  0.4223],\n",
      "        [ 0.4569,  0.9319],\n",
      "        [ 0.8345, -0.0568],\n",
      "        [ 0.9522, -0.0529],\n",
      "        [ 1.0858, -1.0753]]) \n",
      " tensor([[ 2.8577],\n",
      "        [-0.9868],\n",
      "        [ 6.3033],\n",
      "        [-4.0099],\n",
      "        [ 3.5158],\n",
      "        [ 1.5222],\n",
      "        [ 1.9461],\n",
      "        [ 6.0571],\n",
      "        [ 6.2885],\n",
      "        [10.0274]])\n"
     ]
    }
   ],
   "source": [
    "batch_size = 10\n",
    "\n",
    "for X, y in data_iter(batch_size, features, labels):\n",
    "    print(X, '\\n', y)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae1affbe",
   "metadata": {},
   "source": [
    "当我们运行迭代时，我们会连续地获得不同的小批量，直至遍历完整个数据集。 上面实现的迭代对于教学来说很好，但它的执行效率很低，可能会在实际问题上陷入麻烦。 例如，它要求我们将所有数据加载到内存中，并执行大量的随机内存访问。 在深度学习框架中实现的内置迭代器效率要高得多，它可以处理存储在文件中的数据和通过数据流提供的数据。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b00599e7",
   "metadata": {},
   "source": [
    "# 初始化模型参数"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7403ad85",
   "metadata": {},
   "source": [
    "在我们开始用小批量随机梯度下降优化我们的模型参数之前，我们需要先有一些参数。 在下面的代码中，我们通过从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重，并将偏置初始化为0。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "bfdf59d3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:39:40.938558Z",
     "start_time": "2021-09-07T00:39:40.591538Z"
    }
   },
   "outputs": [],
   "source": [
    "w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)\n",
    "b = torch.zeros(1, requires_grad=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d44f0798",
   "metadata": {},
   "source": [
    "在初始化参数之后，我们的任务是更新这些参数，直到这些参数足够拟合我们的数据。 每次更新都需要计算损失函数关于模型参数的梯度。有了这个梯度，我们就可以向减小损失的方向更新每个参数。 因为手动计算梯度很枯燥而且容易出错，所以没有人会手动计算梯度。我们使用 torch的自动微分来计算梯度。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60807223",
   "metadata": {},
   "source": [
    "# 定义模型  \n",
    "接下来，我们必须定义模型，将模型的输入和参数同模型的输出关联起来。 要计算线性模型的输出，我们只需计算输入特征 X 和模型权重 w 的矩阵-向量乘法后加上偏置 b 。注意，上面的 Xw 是一个向量，而 b 是一个标量。当我们用一个向量加一个标量时，标量会被加到向量的每个分量上。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "39d56c8e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:39:38.710431Z",
     "start_time": "2021-09-07T00:39:38.702430Z"
    }
   },
   "outputs": [],
   "source": [
    "def linreg(X, w, b):  #@save\n",
    "    \"\"\"线性回归模型。\"\"\"\n",
    "    return torch.matmul(X, w) + b"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72e2fb18",
   "metadata": {},
   "source": [
    "# 定义损失函数\n",
    "因为要更新模型。需要计算损失函数的梯度，所以我们应该先定义损失函数。 这里我们使用平方损失函数。 在实现中，我们需要将真实值y的形状转换为和预测值y_hat的形状相同。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a2138be0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:39:37.253348Z",
     "start_time": "2021-09-07T00:39:37.245347Z"
    }
   },
   "outputs": [],
   "source": [
    "def squared_loss(y_hat, y):  #@save\n",
    "    \"\"\"均方损失。\"\"\"\n",
    "    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f01d4fb",
   "metadata": {},
   "source": [
    "# 定义优化算法\n",
    "\n",
    "在每一步中，使用从数据集中随机抽取的一个小批量，然后根据参数计算损失的梯度。接下来，朝着减少损失的方向更新我们的参数。 下面的函数实现小批量随机梯度下降更新。该函数接受模型参数集合、学习速率和批量大小作为输入。每一步更新的大小由学习速率lr决定。 因为我们计算的损失是一个批量样本的总和，所以我们用批量大小（batch_size）来归一化步长，这样步长大小就不会取决于我们对批量大小的选择。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "bd2f6e91",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:39:35.322237Z",
     "start_time": "2021-09-07T00:39:35.312237Z"
    }
   },
   "outputs": [],
   "source": [
    "def sgd(params, lr, batch_size):  #@save\n",
    "    \"\"\"小批量随机梯度下降。\"\"\"\n",
    "    with torch.no_grad():\n",
    "        for param in params:\n",
    "            param -= lr * param.grad / batch_size\n",
    "            param.grad.zero_()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "052f17eb",
   "metadata": {},
   "source": [
    "# 训练\n",
    "现在我们已经准备好了模型训练所有需要的要素，可以实现主要的训练过程部分了。 理解这段代码至关重要，因为在整个深度学习的职业生涯中，你会一遍又一遍地看到几乎相同的训练过程。 在每次迭代中，我们读取一小批量训练样本，并通过我们的模型来获得一组预测。 计算完损失后，我们开始反向传播，存储每个参数的梯度。最后，我们调用优化算法sgd来更新模型参数。\n",
    "\n",
    "概括一下，我们将执行以下循环：\n",
    "\n",
    "初始化参数\n",
    "\n",
    "重复，直到完成\n",
    "\n",
    "计算梯度 \n",
    "\n",
    "更新参数 (w,b)\n",
    "\n",
    "在每个迭代周期（epoch）中，我们使用data_iter函数遍历整个数据集，并将训练数据集中所有样本都使用一次（假设样本数能够被批量大小整除）。这里的迭代周期个数num_epochs和学习率lr都是超参数，分别设为3和0.03。设置超参数很棘手，需要通过反复试验进行调整。 我们现在忽略这些细节。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "e289a23b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:41:48.590860Z",
     "start_time": "2021-09-07T00:41:47.530799Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.000050\n",
      "epoch 2, loss 0.000050\n",
      "epoch 3, loss 0.000050\n",
      "epoch 4, loss 0.000050\n",
      "epoch 5, loss 0.000050\n",
      "epoch 6, loss 0.000050\n",
      "epoch 7, loss 0.000050\n",
      "epoch 8, loss 0.000050\n",
      "epoch 9, loss 0.000050\n",
      "epoch 10, loss 0.000050\n"
     ]
    }
   ],
   "source": [
    "lr = 0.03\n",
    "num_epochs = 10\n",
    "net = linreg\n",
    "loss = squared_loss\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for X, y in data_iter(batch_size, features, labels):\n",
    "        l = loss(net(X, w, b), y)  # `X`和`y`的小批量损失\n",
    "        # 因为`l`形状是(`batch_size`, 1)，而不是一个标量。`l`中的所有元素被加到一起，\n",
    "        # 并以此计算关于[`w`, `b`]的梯度\n",
    "        l.sum().backward()\n",
    "        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数\n",
    "    with torch.no_grad():\n",
    "        train_l = loss(net(features, w, b), labels)\n",
    "        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b408a67e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:41:53.775156Z",
     "start_time": "2021-09-07T00:41:53.761155Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w的估计误差: tensor([ 6.2108e-05, -7.2336e-04], grad_fn=<SubBackward0>)\n",
      "b的估计误差: tensor([0.0005], grad_fn=<RsubBackward1>)\n"
     ]
    }
   ],
   "source": [
    "print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')\n",
    "print(f'b的估计误差: {true_b - b}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "e94e5baf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:41:55.703266Z",
     "start_time": "2021-09-07T00:41:55.691266Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.9999],\n",
      "        [-3.3993]], requires_grad=True)\n",
      "tensor([ 2.0000, -3.4000])\n"
     ]
    }
   ],
   "source": [
    "print(w)\n",
    "print(true_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "0c9c6437",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-09-07T00:42:01.121576Z",
     "start_time": "2021-09-07T00:42:01.110576Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([4.1995], requires_grad=True)\n",
      "4.2\n"
     ]
    }
   ],
   "source": [
    "print(b)\n",
    "print(true_b)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5rc1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
